import torch
import torch.onnx
import torch.nn as nn
import torch.nn.functional as F
import onnx
import onnxruntime

# Model definition; must stay identical to the class used at training time so
# that the saved state_dict keys and tensor shapes line up.
class GazeNet(nn.Module):
    """Small two-stage CNN classifier for 3-channel images.

    The size of ``fc1`` (64 * 16 * 16) hard-codes the spatial resolution: both
    convolutions preserve H and W (kernel 3, padding 1, stride 1) and each of
    the two 2x2 max-pools halves them, so inputs must be (N, 3, 64, 64).
    ``forward`` returns raw scores of shape (N, 6); no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Attribute names below are the state_dict keys; renaming them would
        # break loading of existing checkpoints.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 16 * 16, 128)
        # 6 output classes. NOTE(review): an earlier comment claimed 5 —
        # confirm against the training labels (the checkpoint must have 6).
        self.fc2 = nn.Linear(128, 6)

    def forward(self, x):
        """Return scores of shape (N, 6) for an input of shape (N, 3, 64, 64)."""
        x = self.pool(F.relu(self.conv1(x)))  # -> (N, 32, 32, 32)
        x = self.pool(F.relu(self.conv2(x)))  # -> (N, 64, 16, 16)
        # Flatten everything except the batch dimension. A view(-1, 64*16*16)
        # would silently fold surplus elements of a wrongly sized input into
        # the batch dimension; flatten keeps N so fc1 raises a shape error.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Load the trained model weights.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = GazeNet().to(device)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU still loads on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()

# Example input used to trace the graph; GazeNet's fc1 requires 3x64x64.
dummy_input = torch.randn(1, 3, 64, 64).to(device)

# Export to ONNX. Without dynamic_axes every input/output dimension, including
# the batch size (1 here), is fixed in the exported graph. To allow variable
# batch sizes pass:
#   dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
onnx_model_path = 'best_model.onnx'
torch.onnx.export(
    model,
    dummy_input,
    onnx_model_path,
    input_names=['input'],
    output_names=['output'],
    opset_version=11,
)

# Verify the exported file before reporting success: structural validation
# with onnx, then a numerical comparison against PyTorch via onnxruntime (CPU).
onnx.checker.check_model(onnx.load(onnx_model_path))
ort_session = onnxruntime.InferenceSession(
    onnx_model_path, providers=['CPUExecutionProvider']
)
ort_output = ort_session.run(None, {'input': dummy_input.cpu().numpy()})[0]
with torch.no_grad():
    torch_output = model(dummy_input).cpu()
if not torch.allclose(torch.from_numpy(ort_output), torch_output, rtol=1e-3, atol=1e-4):
    max_diff = (torch.from_numpy(ort_output) - torch_output).abs().max().item()
    raise RuntimeError(
        f"ONNX Runtime output differs from PyTorch output (max abs diff {max_diff})"
    )

print(f"Model has been converted to ONNX and saved at {onnx_model_path}")

